import csv
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm 

from models import LeNet, VGG16
from AL_Dataset import ActiveLearning_Framework
from train_test_model import train_model, test_model
from Query_Strategy_Baseline import RandomSampling, MaxEntropy, kCenterGreedy, BADGE
from Query_Strategy_Ours import SparseCoreset
import utils

# Seed every random number generator this script has in scope so that
# repeated runs with the same settings are reproducible.
seed = 0
# NumPy keeps its own global RNG that torch.manual_seed does not touch.
# NOTE(review): presumably the query strategies / dataset split draw from
# np.random -- confirm; seeding it is harmless if they do not.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Disable the cuDNN autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Supported values for each experiment setting.
dataset_names = ['MNIST', 'SVHN', 'cifar10']
methods = ['RandomSampling', 'MaxEntropy', 'kCenterGreedy', 'BADGE', 'SparseCoreset']
model_names = ['LeNet5', 'VGG16']
projection_methods = ['gradient']
optimization_methods = ['greedy', 'prox_iht']

# select current settings here
dataset_name = 'cifar10'
method = 'RandomSampling'
optimization_method = 'prox_iht'
# NOTE(review): model_name only labels the results file; the network that is
# actually built is chosen from dataset_name below -- keep the two in sync.
model_name = 'VGG16'

# Fail fast on a mistyped query strategy. Without this check the error would
# only surface in the acquisition step, after a complete training and test
# round has already been run.
if method not in methods:
    raise ValueError('method not defined')

# Dataset-specific model and labelling schedule.
if dataset_name in ('SVHN', 'cifar10'):
    model = VGG16(10)
    num_seed_data = 3000
    epoch_num = 100
    batch_size = 128
    num_queried_each_round = 3000
    num_round_to_acquire = 6
elif dataset_name == 'MNIST':
    model = LeNet(10)
    num_seed_data = 40
    epoch_num = 150
    batch_size = 32
    num_queried_each_round = 40
    num_round_to_acquire = 15
else:
    raise ValueError('dataset unsupported')

# Settings shared by every supported dataset.
# Total number of labels: the seed set plus one query batch for every round
# after the first.
budget = num_seed_data + (num_round_to_acquire - 1) * num_queried_each_round
num_class = 10
projection_method = 'gradient'
alpha = 1
beta = 1
print('using ' + method + ' for active learning on ' + dataset_name)
# SparseCoreset is the only method that sets the `weighted` flag handed to
# train_model, and the only one trained with an unreduced (per-sample) loss.
weighted = method == 'SparseCoreset'
criterion = nn.CrossEntropyLoss(reduction='none') if weighted else nn.CrossEntropyLoss()
# Evaluation always uses the default (mean-reduced) cross entropy.
test_criterion = nn.CrossEntropyLoss()

validation_set_ratio = 0.1

# The results file name encodes the settings that identify this run.
results_path = './results/' + '_'.join(map(str, (
    dataset_name,
    method,
    model_name,
    seed,
    num_seed_data,
    batch_size,
    num_queried_each_round,
    num_round_to_acquire,
))) + '.csv'

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('running on ' + str(device))
model.to(device)

# Loading Dataset
# Each branch builds three dataset objects:
#   trainset          - training split with the training transform
#   ActiveLearningset - the same training split with the test transform
#   testset           - held-out split with the test transform
if dataset_name == 'cifar10':
    transform_train = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])    # RGB
    transform_test = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])    # RGB
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    ActiveLearningset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

elif dataset_name == 'SVHN':
    # NOTE(review): the test transform uses different normalisation statistics
    # from the train transform, and it is also applied to ActiveLearningset
    # (the train split) -- confirm this is intended.
    transform_train = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.4377, 0.4438, 0.4728),(0.1980, 0.2010, 0.1970))])
    transform_test = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.4524,  0.4525,  0.4690),(0.2194,  0.2266,  0.2285))])
    
    trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform_train)
    ActiveLearningset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform_test)
    testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_test)

elif dataset_name == 'MNIST':
    # Normalize expects per-channel sequences; the trailing commas make these
    # real one-element tuples ("(0.1307)" is just a parenthesised float).
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])    # grayscale

    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    ActiveLearningset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

else:
    raise ValueError('dataset unsupported')

# Wrap the three dataset views in the project's active-learning framework
# object, which the training, testing and query code below operate on.
ALset = ActiveLearning_Framework(
    trainset,
    ActiveLearningset,
    testset,
    num_seed_data,
    num_class,
    validation_set_ratio,
    device,
)

# One dict per completed round; the whole list is written to the CSV each round.
all_results = []

# Column order of the results CSV (used as the DictWriter field names).
results_content = (
    # Columns filled in for every method.
    ['num_training_samples', 'accuracy', 'test_loss', 'train_loss', 'acquire_time',
     'train_time', 'test_time', 'acquired_labels']
    # Remaining columns: presumably the keys of ALset.optim_results, which are
    # only copied into a row when the method is SparseCoreset.
    + ['train_total_loss', 'train_approximation_loss', 'train_variance_loss']
    + ['valid_total_loss', 'valid_approximation_loss', 'valid_variance_loss']
    + ['w_min', 'w_max', 'w_mean_deviation', 'w', 'w_normalize']
    + ['supp', 'selected_coreset_size', 'maximal_coreset_size']
    + ['number_proj_train', 'number_proj_valid', 'alpha', 'beta']
)

# Training and testing
# Each round: build a fresh model, train it on the current labeled pool,
# evaluate it, query a new batch of samples, and rewrite the results CSV.

# Create the results directory up front. open(..., 'w') does not create
# missing parent directories, so without this every save below would fail and
# the run would finish with nothing written to disk.
results_dir = os.path.dirname(results_path)
if results_dir:
    try:
        os.makedirs(results_dir, exist_ok=True)
    except OSError as err:
        # Stay best-effort, like the per-round save below.
        print("I/O error: could not create {}: {}".format(results_dir, err))

num_training_samples = num_seed_data
round_ = 0
while num_training_samples < num_queried_each_round*num_round_to_acquire: 
    # Re-seed before constructing the model, presumably so that every round
    # starts from the same initial weights.
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)
    # The model is rebuilt from scratch each round rather than fine-tuned.
    if dataset_name == 'MNIST':
        model = LeNet(10)  
    if dataset_name == 'cifar10' or dataset_name == 'SVHN':
        model = VGG16(10)
    model.to(device)
    # The loop condition is evaluated on the pool size measured here, i.e.
    # before this round's query enlarges the pool.
    num_training_samples = len(ALset.labeled_pool)
    if dataset_name == 'cifar10':
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)  
    else:
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 
    print('\n \n using {}/{} samples for training...'.format(num_training_samples, budget))

    # Training
    time_start = time.perf_counter()
    train_loss = train_model(model, criterion, test_criterion, optimizer, device, ALset, epoch_num, batch_size, weighted)
    train_time = time.perf_counter() - time_start

    # Testing
    time_start = time.perf_counter()
    accuracy, test_loss = test_model(model, test_criterion, device, ALset, batch_size)
    test_time = time.perf_counter() - time_start

    # Acquire unlabeled batch
    print('acquiring data using ' + method + ' ...')
    time_start = time.perf_counter()
    if method == 'RandomSampling':
        AL_method = RandomSampling(ALset, num_queried_each_round)
    elif method == 'MaxEntropy':
        AL_method = MaxEntropy(model, ALset, num_queried_each_round, batch_size, device)
    elif method == 'kCenterGreedy':
        AL_method = kCenterGreedy(model, ALset, num_queried_each_round, batch_size, device, num_class)
    elif method == 'BADGE':
        AL_method = BADGE(model, ALset, num_queried_each_round, batch_size, device, num_class)
    elif method == 'SparseCoreset':
        AL_method = SparseCoreset(model, ALset, num_queried_each_round, batch_size, device, num_class, projection_method, optimization_method, alpha, beta)
    else:
        raise ValueError('method not defined')

    AL_method.query()
    print('Finished acquiring unlabeled data')
    acquire_time = time.perf_counter() - time_start
    print('Acquire time', acquire_time)
    round_ += 1

    # Save results
    round_results = {'num_training_samples': num_training_samples, 'accuracy': accuracy, 'test_loss': test_loss,
                    'train_loss': train_loss, 'acquire_time': acquire_time,
                    'train_time': train_time,
                    'test_time': test_time, 'acquired_labels': ALset.get_labels_just_moved_stats(),
                    }
    if method == 'SparseCoreset':
        # Copy every entry of ALset.optim_results into this round's row.
        for key in ALset.optim_results:
            round_results[key] = ALset.optim_results[key]
    print('acquired_labels', ALset.get_labels_just_moved_stats(), '\n')
    all_results.append(round_results)
    # Rewrite the whole file every round so that an interrupted run still
    # leaves the results of all completed rounds on disk.
    try:
        # newline='' is what the csv module requires for files it writes;
        # without it, extra blank rows appear on platforms using \r\n.
        with open(results_path, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=results_content)
            writer.writeheader()
            for result in all_results:
                writer.writerow(result)
    except IOError as err:
        # Best effort: report the failure and keep training.
        print("I/O error: {}".format(err))

print('Finished!')

